{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.patchtst"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PatchTST"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The PatchTST model is an efficient Transformer-based model for multivariate time series forecasting.\n",
    "\n",
    "It is based on two key components:\n",
    "- segmentation of time series into windows (patches) which are served as input tokens to Transformer\n",
    "- channel-independence. where each channel contains a single univariate time series."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**References**<br>\n",
    "- [Nie, Y., Nguyen, N. H., Sinthong, P., & Kalagnanam, J. (2022). \"A Time Series is Worth 64 Words: Long-term Forecasting with Transformers\"](https://arxiv.org/pdf/2211.14730.pdf)<br>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. PatchTST.](imgs_models/patchtst.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import math\n",
    "import numpy as np\n",
    "from typing import Optional #, Any, Tuple\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from neuralforecast.common._base_model import BaseModel\n",
    "from neuralforecast.common._modules import RevIN\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import logging\n",
    "import warnings\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc\n",
    "from neuralforecast.common._model_checks import check_model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Backbone"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Auxiliary Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Transpose(nn.Module):\n",
    "    \"\"\"\n",
    "    Transpose\n",
    "    \"\"\"\n",
    "    def __init__(self, *dims, contiguous=False): \n",
    "        super().__init__()\n",
    "        self.dims, self.contiguous = dims, contiguous\n",
    "    def forward(self, x):\n",
    "        if self.contiguous: return x.transpose(*self.dims).contiguous()\n",
    "        else: return x.transpose(*self.dims)\n",
    "\n",
    "def get_activation_fn(activation):\n",
    "    if callable(activation): return activation()\n",
    "    elif activation.lower() == \"relu\": return nn.ReLU()\n",
    "    elif activation.lower() == \"gelu\": return nn.GELU()\n",
    "    raise ValueError(f'{activation} is not available. You can use \"relu\", \"gelu\", or a callable') "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Positional Encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def PositionalEncoding(q_len, hidden_size, normalize=True):\n",
    "    pe = torch.zeros(q_len, hidden_size)\n",
    "    position = torch.arange(0, q_len).unsqueeze(1)\n",
    "    div_term = torch.exp(torch.arange(0, hidden_size, 2) * -(math.log(10000.0) / hidden_size))\n",
    "    pe[:, 0::2] = torch.sin(position * div_term)\n",
    "    pe[:, 1::2] = torch.cos(position * div_term)\n",
    "    if normalize:\n",
    "        pe = pe - pe.mean()\n",
    "        pe = pe / (pe.std() * 10)\n",
    "    return pe\n",
    "\n",
    "SinCosPosEncoding = PositionalEncoding\n",
    "\n",
    "def Coord2dPosEncoding(q_len, hidden_size, exponential=False, normalize=True, eps=1e-3):\n",
    "    x = .5 if exponential else 1\n",
    "    i = 0\n",
    "    for i in range(100):\n",
    "        cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, hidden_size).reshape(1, -1) ** x) - 1\n",
    "        if abs(cpe.mean()) <= eps: break\n",
    "        elif cpe.mean() > eps: x += .001\n",
    "        else: x -= .001\n",
    "        i += 1\n",
    "    if normalize:\n",
    "        cpe = cpe - cpe.mean()\n",
    "        cpe = cpe / (cpe.std() * 10)\n",
    "    return cpe\n",
    "\n",
    "def Coord1dPosEncoding(q_len, exponential=False, normalize=True):\n",
    "    cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1)\n",
    "    if normalize:\n",
    "        cpe = cpe - cpe.mean()\n",
    "        cpe = cpe / (cpe.std() * 10)\n",
    "    return cpe\n",
    "\n",
    "def positional_encoding(pe, learn_pe, q_len, hidden_size):\n",
    "    # Positional encoding\n",
    "    if pe == None:\n",
    "        W_pos = torch.empty((q_len, hidden_size)) # pe = None and learn_pe = False can be used to measure impact of pe\n",
    "        nn.init.uniform_(W_pos, -0.02, 0.02)\n",
    "        learn_pe = False\n",
    "    elif pe == 'zero':\n",
    "        W_pos = torch.empty((q_len, 1))\n",
    "        nn.init.uniform_(W_pos, -0.02, 0.02)\n",
    "    elif pe == 'zeros':\n",
    "        W_pos = torch.empty((q_len, hidden_size))\n",
    "        nn.init.uniform_(W_pos, -0.02, 0.02)\n",
    "    elif pe == 'normal' or pe == 'gauss':\n",
    "        W_pos = torch.zeros((q_len, 1))\n",
    "        torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)\n",
    "    elif pe == 'uniform':\n",
    "        W_pos = torch.zeros((q_len, 1))\n",
    "        nn.init.uniform_(W_pos, a=0.0, b=0.1)\n",
    "    elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)\n",
    "    elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)\n",
    "    elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, hidden_size, exponential=False, normalize=True)\n",
    "    elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, hidden_size, exponential=True, normalize=True)\n",
    "    elif pe == 'sincos': W_pos = PositionalEncoding(q_len, hidden_size, normalize=True)\n",
    "    else: raise ValueError(f\"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', \\\n",
    "        'zeros', 'zero', uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)\")\n",
    "    return nn.Parameter(W_pos, requires_grad=learn_pe)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class PatchTST_backbone(nn.Module):\n",
    "    \"\"\"\n",
    "    PatchTST_backbone\n",
    "    \"\"\"      \n",
    "    def __init__(self, c_in:int, c_out:int, input_size:int, h:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024, \n",
    "                 n_layers:int=3, hidden_size=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None,\n",
    "                 linear_hidden_size:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str=\"gelu\", key_padding_mask:str='auto',\n",
    "                 padding_var:Optional[int]=None, attn_mask:Optional[torch.Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,\n",
    "                 pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None,\n",
    "                 pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False):\n",
    "        \n",
    "        super().__init__()\n",
    "\n",
    "        # RevIn\n",
    "        self.revin = revin\n",
    "        if self.revin: self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)\n",
    "\n",
    "        # Patching\n",
    "        self.patch_len = patch_len\n",
    "        self.stride = stride\n",
    "        self.padding_patch = padding_patch\n",
    "        patch_num = int((input_size - patch_len)/stride + 1)\n",
    "        if padding_patch == 'end': # can be modified to general case\n",
    "            self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) \n",
    "            patch_num += 1\n",
    "\n",
    "        # Backbone \n",
    "        self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len,\n",
    "                                n_layers=n_layers, hidden_size=hidden_size, n_heads=n_heads, d_k=d_k, d_v=d_v, linear_hidden_size=linear_hidden_size,\n",
    "                                attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,\n",
    "                                attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,\n",
    "                                pe=pe, learn_pe=learn_pe)\n",
    "\n",
    "        # Head\n",
    "        self.head_nf = hidden_size * patch_num\n",
    "        self.n_vars = c_in\n",
    "        self.c_out = c_out\n",
    "        self.pretrain_head = pretrain_head\n",
    "        self.head_type = head_type\n",
    "        self.individual = individual\n",
    "\n",
    "        if self.pretrain_head: \n",
    "            self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs\n",
    "        elif head_type == 'flatten': \n",
    "            self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, h, c_out, head_dropout=head_dropout)\n",
    "\n",
    "    def forward(self, z):                                                                   # z: [bs x nvars x seq_len]\n",
    "        # norm\n",
    "        if self.revin: \n",
    "            z = z.permute(0,2,1)\n",
    "            z = self.revin_layer(z, 'norm')\n",
    "            z = z.permute(0,2,1)\n",
    "\n",
    "        # do patching\n",
    "        if self.padding_patch == 'end':\n",
    "            z = self.padding_patch_layer(z)\n",
    "        z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride)                   # z: [bs x nvars x patch_num x patch_len]\n",
    "        z = z.permute(0,1,3,2)                                                              # z: [bs x nvars x patch_len x patch_num]\n",
    "\n",
    "        # model\n",
    "        z = self.backbone(z)                                                                # z: [bs x nvars x hidden_size x patch_num]\n",
    "        z = self.head(z)                                                                    # z: [bs x nvars x h] \n",
    "\n",
    "        # denorm\n",
    "        if self.revin:\n",
    "            z = z.permute(0,2,1)\n",
    "            z = self.revin_layer(z, 'denorm')\n",
    "            z = z.permute(0,2,1)\n",
    "        return z\n",
    "    \n",
    "    def create_pretrain_head(self, head_nf, vars, dropout):\n",
    "        return nn.Sequential(nn.Dropout(dropout),\n",
    "                    nn.Conv1d(head_nf, vars, 1)\n",
    "                    )\n",
    "\n",
    "\n",
    "class Flatten_Head(nn.Module):\n",
    "    \"\"\"\n",
    "    Flatten_Head\n",
    "    \"\"\"        \n",
    "    def __init__(self, individual, n_vars, nf, h, c_out, head_dropout=0):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.individual = individual\n",
    "        self.n_vars = n_vars\n",
    "        self.c_out = c_out\n",
    "        \n",
    "        if self.individual:\n",
    "            self.linears = nn.ModuleList()\n",
    "            self.dropouts = nn.ModuleList()\n",
    "            self.flattens = nn.ModuleList()\n",
    "            for i in range(self.n_vars):\n",
    "                self.flattens.append(nn.Flatten(start_dim=-2))\n",
    "                self.linears.append(nn.Linear(nf, h*c_out))\n",
    "                self.dropouts.append(nn.Dropout(head_dropout))\n",
    "        else:\n",
    "            self.flatten = nn.Flatten(start_dim=-2)\n",
    "            self.linear = nn.Linear(nf, h*c_out)\n",
    "            self.dropout = nn.Dropout(head_dropout)\n",
    "            \n",
    "    def forward(self, x):                                 # x: [bs x nvars x hidden_size x patch_num]\n",
    "        if self.individual:\n",
    "            x_out = []\n",
    "            for i in range(self.n_vars):\n",
    "                z = self.flattens[i](x[:,i,:,:])          # z: [bs x hidden_size * patch_num]\n",
    "                z = self.linears[i](z)                    # z: [bs x h]\n",
    "                z = self.dropouts[i](z)\n",
    "                x_out.append(z)\n",
    "            x = torch.stack(x_out, dim=1)                 # x: [bs x nvars x h]\n",
    "        else:\n",
    "            x = self.flatten(x)\n",
    "            x = self.linear(x)\n",
    "            x = self.dropout(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class TSTiEncoder(nn.Module):  #i means channel-independent\n",
    "    \"\"\"\n",
    "    TSTiEncoder\n",
    "    \"\"\"      \n",
    "    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,\n",
    "                 n_layers=3, hidden_size=128, n_heads=16, d_k=None, d_v=None,\n",
    "                 linear_hidden_size=256, norm='BatchNorm', attn_dropout=0., dropout=0., act=\"gelu\", store_attn=False,\n",
    "                 key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False,\n",
    "                 pe='zeros', learn_pe=True):\n",
    "        \n",
    "        \n",
    "        super().__init__()\n",
    "        \n",
    "        self.patch_num = patch_num\n",
    "        self.patch_len = patch_len\n",
    "        \n",
    "        # Input encoding\n",
    "        q_len = patch_num\n",
    "        self.W_P = nn.Linear(patch_len, hidden_size)        # Eq 1: projection of feature vectors onto a d-dim vector space\n",
    "        self.seq_len = q_len\n",
    "\n",
    "        # Positional encoding\n",
    "        self.W_pos = positional_encoding(pe, learn_pe, q_len, hidden_size)\n",
    "\n",
    "        # Residual dropout\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "        # Encoder\n",
    "        self.encoder = TSTEncoder(q_len, hidden_size, n_heads, d_k=d_k, d_v=d_v, linear_hidden_size=linear_hidden_size, norm=norm, attn_dropout=attn_dropout, dropout=dropout,\n",
    "                                   pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers, store_attn=store_attn)\n",
    "        \n",
    "    def forward(self, x) -> torch.Tensor:                                        # x: [bs x nvars x patch_len x patch_num]\n",
    "        \n",
    "        n_vars = x.shape[1]\n",
    "        # Input encoding\n",
    "        x = x.permute(0,1,3,2)                                                   # x: [bs x nvars x patch_num x patch_len]\n",
    "        x = self.W_P(x)                                                          # x: [bs x nvars x patch_num x hidden_size]\n",
    "\n",
    "        u = torch.reshape(x, (x.shape[0]*x.shape[1],x.shape[2],x.shape[3]))      # u: [bs * nvars x patch_num x hidden_size]\n",
    "        u = self.dropout(u + self.W_pos)                                         # u: [bs * nvars x patch_num x hidden_size]\n",
    "\n",
    "        # Encoder\n",
    "        z = self.encoder(u)                                                      # z: [bs * nvars x patch_num x hidden_size]\n",
    "        z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1]))                # z: [bs x nvars x patch_num x hidden_size]\n",
    "        z = z.permute(0,1,3,2)                                                   # z: [bs x nvars x hidden_size x patch_num]\n",
    "        \n",
    "        return z    \n",
    "            \n",
    "\n",
    "class TSTEncoder(nn.Module):\n",
    "    \"\"\"\n",
    "    TSTEncoder\n",
    "    \"\"\"     \n",
    "    def __init__(self, q_len, hidden_size, n_heads, d_k=None, d_v=None, linear_hidden_size=None, \n",
    "                        norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',\n",
    "                        res_attention=False, n_layers=1, pre_norm=False, store_attn=False):\n",
    "        super().__init__()\n",
    "\n",
    "        self.layers = nn.ModuleList([TSTEncoderLayer(q_len, hidden_size, n_heads=n_heads, d_k=d_k, d_v=d_v, linear_hidden_size=linear_hidden_size, norm=norm,\n",
    "                                                      attn_dropout=attn_dropout, dropout=dropout,\n",
    "                                                      activation=activation, res_attention=res_attention,\n",
    "                                                      pre_norm=pre_norm, store_attn=store_attn) for i in range(n_layers)])\n",
    "        self.res_attention = res_attention\n",
    "\n",
    "    def forward(self, src:torch.Tensor, key_padding_mask:Optional[torch.Tensor]=None, attn_mask:Optional[torch.Tensor]=None):\n",
    "        output = src\n",
    "        scores = None\n",
    "        if self.res_attention:\n",
    "            for mod in self.layers: output, scores = mod(output, prev=scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "            return output\n",
    "        else:\n",
    "            for mod in self.layers: output = mod(output, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "            return output\n",
    "\n",
    "\n",
    "class TSTEncoderLayer(nn.Module):\n",
    "    \"\"\"\n",
    "    TSTEncoderLayer\n",
    "    \"\"\"      \n",
    "    def __init__(self, q_len, hidden_size, n_heads, d_k=None, d_v=None, linear_hidden_size=256, store_attn=False,\n",
    "                 norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation=\"gelu\", res_attention=False, pre_norm=False):\n",
    "        super().__init__()\n",
    "        assert not hidden_size%n_heads, f\"hidden_size ({hidden_size}) must be divisible by n_heads ({n_heads})\"\n",
    "        d_k = hidden_size // n_heads if d_k is None else d_k\n",
    "        d_v = hidden_size // n_heads if d_v is None else d_v\n",
    "\n",
    "        # Multi-Head attention\n",
    "        self.res_attention = res_attention\n",
    "        self.self_attn = _MultiheadAttention(hidden_size, n_heads, d_k, d_v, attn_dropout=attn_dropout,\n",
    "                                             proj_dropout=dropout, res_attention=res_attention)\n",
    "\n",
    "        # Add & Norm\n",
    "        self.dropout_attn = nn.Dropout(dropout)\n",
    "        if \"batch\" in norm.lower():\n",
    "            self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(hidden_size), Transpose(1,2))\n",
    "        else:\n",
    "            self.norm_attn = nn.LayerNorm(hidden_size)\n",
    "\n",
    "        # Position-wise Feed-Forward\n",
    "        self.ff = nn.Sequential(nn.Linear(hidden_size, linear_hidden_size, bias=bias),\n",
    "                                get_activation_fn(activation),\n",
    "                                nn.Dropout(dropout),\n",
    "                                nn.Linear(linear_hidden_size, hidden_size, bias=bias))\n",
    "\n",
    "        # Add & Norm\n",
    "        self.dropout_ffn = nn.Dropout(dropout)\n",
    "        if \"batch\" in norm.lower():\n",
    "            self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(hidden_size), Transpose(1,2))\n",
    "        else:\n",
    "            self.norm_ffn = nn.LayerNorm(hidden_size)\n",
    "\n",
    "        self.pre_norm = pre_norm\n",
    "        self.store_attn = store_attn\n",
    "\n",
    "    def forward(self, src:torch.Tensor, prev:Optional[torch.Tensor]=None,\n",
    "                key_padding_mask:Optional[torch.Tensor]=None,\n",
    "                attn_mask:Optional[torch.Tensor]=None): # -> Tuple[torch.Tensor, Any]:\n",
    "\n",
    "        # Multi-Head attention sublayer\n",
    "        if self.pre_norm:\n",
    "            src = self.norm_attn(src)\n",
    "        ## Multi-Head attention\n",
    "        if self.res_attention:\n",
    "            src2, attn, scores = self.self_attn(src, src, src, prev,\n",
    "                                                key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        else:\n",
    "            src2, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        if self.store_attn:\n",
    "            self.attn = attn\n",
    "        ## Add & Norm\n",
    "        src = src + self.dropout_attn(src2) # Add: residual connection with residual dropout\n",
    "        if not self.pre_norm:\n",
    "            src = self.norm_attn(src)\n",
    "\n",
    "        # Feed-forward sublayer\n",
    "        if self.pre_norm:\n",
    "            src = self.norm_ffn(src)\n",
    "        ## Position-wise Feed-Forward\n",
    "        src2 = self.ff(src)\n",
    "        ## Add & Norm\n",
    "        src = src + self.dropout_ffn(src2) # Add: residual connection with residual dropout\n",
    "        if not self.pre_norm:\n",
    "            src = self.norm_ffn(src)\n",
    "\n",
    "        if self.res_attention:\n",
    "            return src, scores\n",
    "        else:\n",
    "            return src\n",
    "\n",
    "\n",
    "class _MultiheadAttention(nn.Module):\n",
    "    \"\"\"\n",
    "    _MultiheadAttention\n",
    "    \"\"\"       \n",
    "    def __init__(self, hidden_size, n_heads, d_k=None, d_v=None,\n",
    "                 res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):\n",
    "        \"\"\"\n",
    "        Multi Head Attention Layer\n",
    "        Input shape:\n",
    "            Q:       [batch_size (bs) x max_q_len x hidden_size]\n",
    "            K, V:    [batch_size (bs) x q_len x hidden_size]\n",
    "            mask:    [q_len x q_len]\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        d_k = hidden_size // n_heads if d_k is None else d_k\n",
    "        d_v = hidden_size // n_heads if d_v is None else d_v\n",
    "\n",
    "        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v\n",
    "\n",
    "        self.W_Q = nn.Linear(hidden_size, d_k * n_heads, bias=qkv_bias)\n",
    "        self.W_K = nn.Linear(hidden_size, d_k * n_heads, bias=qkv_bias)\n",
    "        self.W_V = nn.Linear(hidden_size, d_v * n_heads, bias=qkv_bias)\n",
    "\n",
    "        # Scaled Dot-Product Attention (multiple heads)\n",
    "        self.res_attention = res_attention\n",
    "        self.sdp_attn = _ScaledDotProductAttention(hidden_size, n_heads, attn_dropout=attn_dropout,\n",
    "                                                   res_attention=self.res_attention, lsa=lsa)\n",
    "\n",
    "        # Poject output\n",
    "        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, hidden_size), nn.Dropout(proj_dropout))\n",
    "\n",
    "    def forward(self, Q:torch.Tensor, K:Optional[torch.Tensor]=None, V:Optional[torch.Tensor]=None, prev:Optional[torch.Tensor]=None,\n",
    "                key_padding_mask:Optional[torch.Tensor]=None, attn_mask:Optional[torch.Tensor]=None):\n",
    "\n",
    "        bs = Q.size(0)\n",
    "        if K is None: K = Q\n",
    "        if V is None: V = Q\n",
    "\n",
    "        # Linear (+ split in multiple heads)\n",
    "        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s    : [bs x n_heads x max_q_len x d_k]\n",
    "        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s    : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)\n",
    "        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s    : [bs x n_heads x q_len x d_v]\n",
    "\n",
    "        # Apply Scaled Dot-Product Attention (multiple heads)\n",
    "        if self.res_attention:\n",
    "            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s,\n",
    "                                                    prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        else:\n",
    "            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]\n",
    "\n",
    "        # back to the original inputs dimensions\n",
    "        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]\n",
    "        output = self.to_out(output)\n",
    "\n",
    "        if self.res_attention: return output, attn_weights, attn_scores\n",
    "        else: return output, attn_weights\n",
    "\n",
    "\n",
    "class _ScaledDotProductAttention(nn.Module):\n",
    "    \"\"\"\n",
    "    Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer\n",
    "    (Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets\n",
    "    by Lee et al, 2021)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_size, n_heads, attn_dropout=0., res_attention=False, lsa=False):\n",
    "        super().__init__()\n",
    "        self.attn_dropout = nn.Dropout(attn_dropout)\n",
    "        self.res_attention = res_attention\n",
    "        head_dim = hidden_size // n_heads\n",
    "        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)\n",
    "        self.lsa = lsa\n",
    "\n",
    "    def forward(self, q:torch.Tensor, k:torch.Tensor, v:torch.Tensor,\n",
    "                prev:Optional[torch.Tensor]=None, key_padding_mask:Optional[torch.Tensor]=None,\n",
    "                attn_mask:Optional[torch.Tensor]=None):\n",
    "        '''\n",
    "        Input shape:\n",
    "            q               : [bs x n_heads x max_q_len x d_k]\n",
    "            k               : [bs x n_heads x d_k x seq_len]\n",
    "            v               : [bs x n_heads x seq_len x d_v]\n",
    "            prev            : [bs x n_heads x q_len x seq_len]\n",
    "            key_padding_mask: [bs x seq_len]\n",
    "            attn_mask       : [1 x seq_len x seq_len]\n",
    "        Output shape:\n",
    "            output:  [bs x n_heads x q_len x d_v]\n",
    "            attn   : [bs x n_heads x q_len x seq_len]\n",
    "            scores : [bs x n_heads x q_len x seq_len]\n",
    "        '''\n",
    "        if not self.res_attention:\n",
    "            # Use torch's built-in flash attention for efficient computation\n",
    "            # Note: This will not return attention weights/scores\n",
    "            # The shapes of q, k, v must be: [batch, n_heads, seq_len, head_dim]\n",
    "            # Reshape q, k, v into [batch*n_heads, seq_len, head_dim] as required by torch.nn.functional.scaled_dot_product_attention\n",
    "            bs, n_heads, seq_len, head_dim = q.shape\n",
    "            q_ = q.reshape(bs * n_heads, seq_len, head_dim)\n",
    "            k_ = k.permute(0, 1, 3, 2).reshape(bs * n_heads, seq_len, head_dim)\n",
    "            v_ = v.reshape(bs * n_heads, seq_len, head_dim)\n",
    "            # If attn_mask exists, convert it to the appropriate format for flash attention (e.g. [batch*n_heads, seq_len, seq_len])\n",
    "            if attn_mask is not None:\n",
    "                attn_mask = attn_mask.repeat(bs * n_heads, 1, 1)\n",
    "            output = F.scaled_dot_product_attention(\n",
    "                q_, k_, v_, attn_mask=attn_mask,\n",
    "                dropout_p=self.attn_dropout.p, is_causal=False)\n",
    "            # Restore the original shape\n",
    "            output = output.reshape(bs, n_heads, seq_len, head_dim)\n",
    "            return output, None\n",
    "        else:\n",
    "            # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence\n",
    "            attn_scores = torch.matmul(q, k) * self.scale      # attn_scores : [bs x n_heads x max_q_len x q_len]\n",
    "\n",
    "            # Add pre-softmax attention scores from the previous layer (optional)\n",
    "            if prev is not None: attn_scores = attn_scores + prev\n",
    "\n",
    "            # Attention mask (optional)\n",
    "            if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len\n",
    "                if attn_mask.dtype == torch.bool:\n",
    "                    attn_scores.masked_fill_(attn_mask, -np.inf)\n",
    "                else:\n",
    "                    attn_scores += attn_mask\n",
    "\n",
    "            # Key padding mask (optional)\n",
    "            if key_padding_mask is not None:                              # mask with shape [bs x q_len] (only when max_w_len == q_len)\n",
    "                attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)\n",
    "\n",
    "            # normalize the attention weights\n",
    "            attn_weights = F.softmax(attn_scores, dim=-1)                 # attn_weights   : [bs x n_heads x max_q_len x q_len]\n",
    "            attn_weights = self.attn_dropout(attn_weights)\n",
    "\n",
    "            # compute the new values given the attention weights\n",
    "            output = torch.matmul(attn_weights, v)                        # output: [bs x n_heads x max_q_len x d_v]\n",
    "\n",
    "            return output, attn_weights, attn_scores\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class PatchTST(BaseModel):\n",
    "    \"\"\" PatchTST\n",
    "\n",
    "    The PatchTST model is an efficient Transformer-based model for multivariate time series forecasting.\n",
    "\n",
    "    It is based on two key components:\n",
    "    - segmentation of time series into windows (patches) which are served as input tokens to Transformer\n",
    "    - channel-independence, where each channel contains a single univariate time series.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `h`: int, Forecast horizon. <br>\n",
    "    `input_size`: int, autoregressive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns.<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns.<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns.<br>\n",
    "    `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n",
    "    `encoder_layers`: int=3, number of layers for encoder.<br>\n",
    "    `n_heads`: int=16, number of attention heads.<br>\n",
    "    `hidden_size`: int=128, units of embeddings and encoders.<br>\n",
    "    `linear_hidden_size`: int=256, units of linear layer.<br>\n",
    "    `dropout`: float=0.2, dropout rate for residual connection.<br>\n",
    "    `fc_dropout`: float=0.2, dropout rate for linear layer.<br>\n",
    "    `head_dropout`: float=0.0, dropout rate for Flatten head layer.<br>\n",
    "    `attn_dropout`: float=0.0, dropout rate for attention layer.<br>\n",
    "    `patch_len`: int=16, length of patch. Note: patch_len = min(patch_len, input_size + stride).<br>\n",
    "    `stride`: int=8, stride of patch.<br>\n",
    "    `revin`: bool=True, bool to use RevIn.<br>\n",
    "    `revin_affine`: bool=False, bool to use affine in RevIn.<br>\n",
    "    `revin_subtract_last`: bool=True, bool to use subtract last in RevIn.<br>\n",
    "    `activation`: str='gelu', activation from ['gelu','relu'].<br>\n",
    "    `res_attention`: bool=True, bool to use residual attention.<br>\n",
    "    `batch_normalization`: bool=False, bool to use batch normalization (forwarded to the backbone as `pre_norm`).<br>\n",
    "    `learn_pos_embed`: bool=True, bool to learn positional embedding.<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=5000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-4, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of different series in each batch.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>\n",
    "    `windows_batch_size`: int=1024, number of windows to sample in each training batch, default uses all.<br>\n",
    "    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
    "    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
    "    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
    "    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
    "    `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
    "    `**trainer_kwargs`: keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
    "\n",
    "    **References:**<br>\n",
    "    - [Nie, Y., Nguyen, N. H., Sinthong, P., & Kalagnanam, J. (2022). \"A Time Series is Worth 64 Words: Long-term Forecasting with Transformers\"](https://arxiv.org/pdf/2211.14730.pdf)\n",
    "    \"\"\"\n",
    "    # Class attributes\n",
    "    EXOGENOUS_FUTR = False\n",
    "    EXOGENOUS_HIST = False\n",
    "    EXOGENOUS_STAT = False\n",
    "    MULTIVARIATE = False    # If the model produces multivariate forecasts (True) or univariate (False)\n",
    "    RECURRENT = False       # If the model produces forecasts recursively (True) or direct (False)\n",
    "\n",
    "    def __init__(self,\n",
    "                 h,\n",
    "                 input_size,\n",
    "                 stat_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 futr_exog_list = None,\n",
    "                 exclude_insample_y = False,\n",
    "                 encoder_layers: int = 3,\n",
    "                 n_heads: int = 16,\n",
    "                 hidden_size: int = 128,\n",
    "                 linear_hidden_size: int = 256,\n",
    "                 dropout: float = 0.2,\n",
    "                 fc_dropout: float = 0.2,\n",
    "                 head_dropout: float = 0.0,\n",
    "                 attn_dropout: float = 0.,\n",
    "                 patch_len: int = 16,\n",
    "                 stride: int = 8,\n",
    "                 revin: bool = True,\n",
    "                 revin_affine: bool = False,\n",
    "                 revin_subtract_last: bool = True,\n",
    "                 activation: str = \"gelu\",\n",
    "                 res_attention: bool = True,\n",
    "                 batch_normalization: bool = False,\n",
    "                 learn_pos_embed: bool = True,\n",
    "                 loss = MAE(),\n",
    "                 valid_loss = None,\n",
    "                 max_steps: int = 5000,\n",
    "                 learning_rate: float = 1e-4,\n",
    "                 num_lr_decays: int = -1,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size: int = 32,\n",
    "                 valid_batch_size: Optional[int] = None,\n",
    "                 windows_batch_size = 1024,\n",
    "                 inference_windows_batch_size: int = 1024,\n",
    "                 start_padding_enabled = False,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'identity',\n",
    "                 random_seed: int = 1,\n",
    "                 drop_last_loader: bool = False,\n",
    "                 alias: Optional[str] = None,\n",
    "                 optimizer = None,\n",
    "                 optimizer_kwargs = None,\n",
    "                 lr_scheduler = None,\n",
    "                 lr_scheduler_kwargs = None,\n",
    "                 dataloader_kwargs = None,\n",
    "                 **trainer_kwargs):\n",
    "        \"\"\"Forward the shared training options to `BaseModel` and build the PatchTST backbone.\"\"\"\n",
    "        super().__init__(\n",
    "            h=h,\n",
    "            input_size=input_size,\n",
    "            stat_exog_list=stat_exog_list,\n",
    "            hist_exog_list=hist_exog_list,\n",
    "            futr_exog_list=futr_exog_list,\n",
    "            exclude_insample_y=exclude_insample_y,\n",
    "            loss=loss,\n",
    "            valid_loss=valid_loss,\n",
    "            max_steps=max_steps,\n",
    "            learning_rate=learning_rate,\n",
    "            num_lr_decays=num_lr_decays,\n",
    "            early_stop_patience_steps=early_stop_patience_steps,\n",
    "            val_check_steps=val_check_steps,\n",
    "            batch_size=batch_size,\n",
    "            valid_batch_size=valid_batch_size,\n",
    "            windows_batch_size=windows_batch_size,\n",
    "            inference_windows_batch_size=inference_windows_batch_size,\n",
    "            start_padding_enabled=start_padding_enabled,\n",
    "            step_size=step_size,\n",
    "            scaler_type=scaler_type,\n",
    "            random_seed=random_seed,\n",
    "            drop_last_loader=drop_last_loader,\n",
    "            alias=alias,\n",
    "            optimizer=optimizer,\n",
    "            optimizer_kwargs=optimizer_kwargs,\n",
    "            lr_scheduler=lr_scheduler,\n",
    "            lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
    "            dataloader_kwargs=dataloader_kwargs,\n",
    "            **trainer_kwargs,\n",
    "        )\n",
    "\n",
    "        # Enforce correct patch_len, regardless of user input:\n",
    "        # a patch is never longer than input_size + stride.\n",
    "        patch_len = min(input_size + stride, patch_len)\n",
    "\n",
    "        # Number of outputs per forecast step, as reported by the loss\n",
    "        c_out = self.loss.outputsize_multiplier\n",
    "\n",
    "        # Fixed hyperparameters\n",
    "        c_in = 1                  # Always univariate\n",
    "        padding_patch = 'end'     # Padding at the end\n",
    "        pretrain_head = False     # No pretrained head\n",
    "        norm = 'BatchNorm'        # Normalization type handed to the backbone\n",
    "        pe = 'zeros'              # Initial zeros for positional encoding\n",
    "        d_k = None                # Key dimension\n",
    "        d_v = None                # Value dimension\n",
    "        store_attn = False        # Store attention weights\n",
    "        head_type = 'flatten'     # Head type\n",
    "        individual = False        # Separate heads for each time series\n",
    "        max_seq_len = 1024        # Not used\n",
    "        key_padding_mask = 'auto' # Not used\n",
    "        padding_var = None        # Not used\n",
    "        attn_mask = None          # Not used\n",
    "\n",
    "        self.model = PatchTST_backbone(\n",
    "            c_in=c_in,\n",
    "            c_out=c_out,\n",
    "            input_size=input_size,\n",
    "            h=h,\n",
    "            patch_len=patch_len,\n",
    "            stride=stride,\n",
    "            max_seq_len=max_seq_len,\n",
    "            n_layers=encoder_layers,\n",
    "            hidden_size=hidden_size,\n",
    "            n_heads=n_heads,\n",
    "            d_k=d_k,\n",
    "            d_v=d_v,\n",
    "            linear_hidden_size=linear_hidden_size,\n",
    "            norm=norm,\n",
    "            attn_dropout=attn_dropout,\n",
    "            dropout=dropout,\n",
    "            act=activation,\n",
    "            key_padding_mask=key_padding_mask,\n",
    "            padding_var=padding_var,\n",
    "            attn_mask=attn_mask,\n",
    "            res_attention=res_attention,\n",
    "            pre_norm=batch_normalization,\n",
    "            store_attn=store_attn,\n",
    "            pe=pe,\n",
    "            learn_pe=learn_pos_embed,\n",
    "            fc_dropout=fc_dropout,\n",
    "            head_dropout=head_dropout,\n",
    "            padding_patch=padding_patch,\n",
    "            pretrain_head=pretrain_head,\n",
    "            head_type=head_type,\n",
    "            individual=individual,\n",
    "            revin=revin,\n",
    "            affine=revin_affine,\n",
    "            subtract_last=revin_subtract_last,\n",
    "        )\n",
    "\n",
    "    def forward(self, windows_batch):\n",
    "        \"\"\"Run the backbone on `windows_batch['insample_y']` and return a [Batch, h, c_out] forecast.\"\"\"\n",
    "\n",
    "        # Parse windows_batch\n",
    "        x = windows_batch['insample_y']  # 3-axis tensor, presumably [Batch, input_size, 1] -- TODO confirm\n",
    "\n",
    "        x = x.permute(0,2,1)    # x: [Batch, 1, input_size]\n",
    "        x = self.model(x)\n",
    "        forecast = x.reshape(x.shape[0], self.h, -1) # x: [Batch, h, c_out]\n",
    "\n",
    "        return forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(PatchTST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(PatchTST.fit, name='PatchTST.fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(PatchTST.predict, name='PatchTST.predict')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Unit tests for models\n",
    "# Keep the Lightning loggers quiet (errors only) while the checks run\n",
    "for noisy_logger in (\"pytorch_lightning\", \"lightning_fabric\"):\n",
    "    logging.getLogger(noisy_logger).setLevel(logging.ERROR)\n",
    "\n",
    "# Warnings are suppressed only for the duration of the model check\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    check_model(PatchTST, [\"airpassengers\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import PatchTST\n",
    "from neuralforecast.losses.pytorch import DistributionLoss\n",
    "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n",
    "\n",
    "AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
    "\n",
    "# Split at the 12th-from-last timestamp: everything before it is train, the rest is test\n",
    "test_start = AirPassengersPanel['ds'].values[-12]\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds < test_start] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds >= test_start].reset_index(drop=True) # 12 test\n",
    "\n",
    "model = PatchTST(h=12,\n",
    "                 input_size=104,\n",
    "                 patch_len=24,\n",
    "                 stride=24,\n",
    "                 revin=False,\n",
    "                 hidden_size=16,\n",
    "                 n_heads=4,\n",
    "                 scaler_type='robust',\n",
    "                 loss=DistributionLoss(distribution='StudentT', level=[80, 90]),\n",
    "                 learning_rate=1e-3,\n",
    "                 max_steps=500,\n",
    "                 val_check_steps=50,\n",
    "                 early_stop_patience_steps=2)\n",
    "\n",
    "nf = NeuralForecast(\n",
    "    models=[model],\n",
    "    freq='ME'\n",
    ")\n",
    "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "forecasts = nf.predict(futr_df=Y_test_df)\n",
    "\n",
    "# Put the forecast columns next to the test rows, then prepend the training history\n",
    "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "# Plot a single series; the ground-truth line is the same for both kinds of loss\n",
    "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "\n",
    "if model.loss.is_distribution_output:\n",
    "    plt.plot(plot_df['ds'], plot_df['PatchTST-median'], c='blue', label='median')\n",
    "    # The interval is drawn over the last 12 rows only (positional slice)\n",
    "    plt.fill_between(x=plot_df['ds'].iloc[-12:],\n",
    "                     y1=plot_df['PatchTST-lo-90'].iloc[-12:].values,\n",
    "                     y2=plot_df['PatchTST-hi-90'].iloc[-12:].values,\n",
    "                     alpha=0.4, label='level 90')\n",
    "else:\n",
    "    plt.plot(plot_df['ds'], plot_df['PatchTST'], c='blue', label='Forecast')\n",
    "\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
